import os
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from config import Config
from feature_engineering.data_engineering import data_engineer_benchmark, span_data_2d, span_data_3d
import logging
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import sys
import pickle
import dgl
from scipy.io import loadmat
import yaml

# Module-level logger for this entry-point script. No handler or level is
# configured in this file — presumably left to the caller / logging config.
logger: logging.Logger = logging.getLogger(__name__)
# NOTE(review): disabled sys.path tweak kept from an earlier layout; remove if no longer needed.
# sys.path.append("..")


# Default YAML config for every method that main() dispatches to. An explicit
# --config path always takes precedence over this table.
_DEFAULT_CONFIGS = {
    'mcnn': "config/mcnn_cfg.yaml",
    'stan': "config/stan_cfg.yaml",
    'stan_2d': "config/stan_2d_cfg.yaml",
    'stagn': "config/stagn_cfg.yaml",
    'gtan': "config/gtan_cfg.yaml",
    'rgtan': "config/rgtan_cfg.yaml",
    'hogrl': "config/hogrl_cfg.yaml",
    'hogrl_GradConf': "config/hogrl_cfg.yaml",
    'gtan_dual': "config/gtan_cfg.yaml",
    'rgtan_dual': "config/rgtan_cfg.yaml",
}

# CLI options copied on top of the loaded YAML dict; they win over YAML keys of
# the same name. '--config' itself is deliberately not copied into the result.
_CLI_OVERRIDES = (
    'gpu_id', 'method', 'labeled_ratio', 'mu', 'mu_rampup',
    'consistency_rampup', 'num_labeled', 'pgd_nums', 'k_percent',
    'gamma_focal', 'gamma_ga', 'gamma_grad', 'alpha', 'min_class_factor',
    'cluster_temperature', 'temperature_c',
)


def _build_parser():
    """Return the command-line parser for this entry point."""
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter,
                            conflict_handler='resolve')
    parser.add_argument('--method', type=str, default=None,
                        help='Method to run; also selects the default YAML config')
    parser.add_argument('--gpu_id', type=int, default=0, help='GPU device ID')
    parser.add_argument('--labeled_ratio', type=float, default=0.3,
                        help='Ratio of labeled data to keep in training set')
    parser.add_argument('--mu', type=float, default=2.0,
                        help='Weight for pseudo label loss')
    # store_true with default=True can never switch the flag off, so pair it
    # with an explicit negative switch that writes to the same destination.
    parser.add_argument('--mu_rampup', action='store_true', default=True,
                        help='Whether to use rampup for mu')
    parser.add_argument('--no_mu_rampup', dest='mu_rampup', action='store_false',
                        help='Disable rampup for mu')
    parser.add_argument('--consistency_rampup', type=float, default=None,
                        help='Rampup length for consistency weight')
    parser.add_argument('--num_labeled', type=int, default=None,
                        help='Number of labeled data to keep in training set')
    parser.add_argument('--config', type=str, default=None,
                        help='Path to configuration file (overrides default method config)')
    parser.add_argument('--pgd_nums', type=int, default=50,
                        help='Number of PGD iterations')
    parser.add_argument('--k_percent', type=int, default=10,
                        help='Forwarded to the selected method as args["k_percent"]')
    parser.add_argument('--gamma_focal', type=float, default=2,
                        help='Forwarded to the selected method as args["gamma_focal"]')
    parser.add_argument('--gamma_ga', type=float, default=0.5,
                        help='Forwarded to the selected method as args["gamma_ga"]')
    parser.add_argument('--gamma_grad', type=float, default=1,
                        help='Forwarded to the selected method as args["gamma_grad"]')
    parser.add_argument('--alpha', type=float, default=0.05,
                        help='Forwarded to the selected method as args["alpha"]')
    parser.add_argument('--min_class_factor', type=float, default=3.5,
                        help='Forwarded to the selected method as args["min_class_factor"]')
    parser.add_argument('--cluster_temperature', type=float, default=0.8,
                        help='Forwarded to the selected method as args["cluster_temperature"]')
    parser.add_argument('--temperature_c', type=float, default=0.5,
                        help='Forwarded to the selected method as args["temperature_c"]')
    return parser


def _resolve_config_path(method, config_path):
    """Return the YAML file to load: ``config_path`` if given, else the method default.

    Raises:
        NotImplementedError: if no ``config_path`` is given and ``method`` has
            no default config.
    """
    if config_path is not None:
        return config_path
    try:
        return _DEFAULT_CONFIGS[method]
    except KeyError:
        raise NotImplementedError(f"Unsupported method: {method!r}.") from None


def parse_args(argv=None):
    """Parse CLI flags, load the matching YAML config, and merge the two.

    Args:
        argv: argument list to parse; ``None`` (the default) means
            ``sys.argv[1:]``, exactly as ``ArgumentParser.parse_args`` does.

    Returns:
        dict: the YAML config with every option in ``_CLI_OVERRIDES``
        written on top (CLI values replace same-named YAML keys).

    Raises:
        NotImplementedError: if ``--method`` has no default config and no
            ``--config`` path was given.
    """
    cli = vars(_build_parser().parse_args(argv))  # parse once, reuse everywhere
    yaml_file = _resolve_config_path(cli['method'], cli['config'])
    with open(yaml_file, encoding='utf-8') as file:
        # An empty YAML file loads as None; treat it as an empty config.
        args = yaml.safe_load(file) or {}
    for key in _CLI_OVERRIDES:
        args[key] = cli[key]
    return args


def base_load_data(args: dict):
    """Create (or reuse) the train/test split of S-FFSD for the base models.

    Spans the raw transaction table into samples with ``span_data_3d``
    (method ``'stan'``) or ``span_data_2d`` (any other method), makes a
    stratified train/test split, and saves the four arrays with ``np.save``.
    If all four configured output files already exist, the function returns
    immediately and the previous split is reused.

    Args:
        args: config dict. Reads ``'method'``, ``'test_size'`` and the output
            paths ``'trainfeature'``, ``'testfeature'``, ``'trainlabel'``,
            ``'testlabel'``. The optional ``'data_path'`` overrides the CSV
            location (default ``"data/S-FFSD.csv"``).

    Returns:
        None. Results are written to disk only.
    """
    out_paths = [args['trainfeature'], args['testfeature'],
                 args['trainlabel'], args['testlabel']]

    def _saved(path):
        # np.save appends ".npy" when the given name lacks it, so accept both.
        path = os.fspath(path)
        return os.path.exists(path) or (
            not path.endswith('.npy') and os.path.exists(path + '.npy'))

    # Reuse a previous split only if *this* config's outputs are all present.
    # Checking a fixed path instead could skip generation for a config whose
    # files live elsewhere, leaving the caller with nothing to load.
    if all(_saved(p) for p in out_paths):
        return

    # Read the CSV only once we know the split actually has to be built.
    feat_df = pd.read_csv(args.get('data_path', "data/S-FFSD.csv"))
    # 'stan' consumes 3d spans; every other base method consumes 2d spans.
    span = span_data_3d if args['method'] == 'stan' else span_data_2d
    features, labels = span(feat_df)
    # NOTE: no random_state is passed, so a freshly built split differs
    # between runs; the cached files pin it afterwards.
    trf, tef, trl, tel = train_test_split(
        features, labels, train_size=1 - args['test_size'],
        stratify=labels, shuffle=True)
    for path, arr in zip(out_paths, (trf, tef, trl, tel)):
        np.save(path, arr)


def main(args):
    """Run the training/evaluation entry point selected by ``args['method']``.

    Every method's module is imported lazily inside its handler, so only the
    selected method's dependencies are loaded.

    Args:
        args: config dict (as produced by ``parse_args``). ``'method'`` picks
            the handler; the remaining keys are forwarded as each method needs.

    Raises:
        NotImplementedError: if ``args['method']`` matches no handler.
    """
    def run_stan(stan_main, mode):
        # Shared driver for the 2d and 3d STAN variants.
        base_load_data(args)
        stan_main(
            args['trainfeature'],
            args['trainlabel'],
            args['testfeature'],
            args['testlabel'],
            mode=mode,
            epochs=args['epochs'],
            batch_size=args['batch_size'],
            attention_hidden_dim=args['attention_hidden_dim'],
            lr=args['lr'],
            device=args['device'],
        )

    def run_mcnn():
        from methods.mcnn.mcnn_main import mcnn_main
        base_load_data(args)
        mcnn_main(
            args['trainfeature'],
            args['trainlabel'],
            args['testfeature'],
            args['testlabel'],
            epochs=args['epochs'],
            batch_size=args['batch_size'],
            lr=args['lr'],
            device=args['device'],
        )

    def run_stan_2d():
        from methods.stan.stan_2d_main import stan_main
        run_stan(stan_main, '2d')

    def run_stan_3d():
        from methods.stan.stan_main import stan_main
        run_stan(stan_main, '3d')

    def run_stagn():
        from methods.stagn.stagn_main import stagn_main, load_stagn_data
        features, labels, graph = load_stagn_data(args)
        stagn_main(
            features,
            labels,
            args['test_size'],
            graph,
            mode='2d',
            epochs=args['epochs'],
            attention_hidden_dim=args['attention_hidden_dim'],
            lr=args['lr'],
            device=args['device'],
        )

    def run_gtan():
        from methods.gtan.gtan_main import gtan_main, load_gtan_data
        loaded = load_gtan_data(args['dataset'], args['test_size'])
        feat_data, labels, train_idx, test_idx, graph, cat_features = loaded
        gtan_main(feat_data, graph, train_idx, test_idx, labels, args, cat_features)

    def run_rgtan():
        # NOTE: 'loda_rgtan_data' is the name exported by the rgtan module.
        from methods.rgtan.rgtan_main import rgtan_main, loda_rgtan_data
        loaded = loda_rgtan_data(args['dataset'], args['test_size'])
        (feat_data, labels, train_idx, test_idx,
         graph, cat_features, neigh_features) = loaded
        heads = args['nei_att_heads'][args['dataset']]
        rgtan_main(feat_data, graph, train_idx, test_idx, labels, args,
                   cat_features, neigh_features, nei_att_head=heads)

    def run_hogrl():
        from methods.hogrl.hogrl_main import hogrl_main
        hogrl_main(args)

    def run_hogrl_gradconf():
        from methods.hogrl.hogrl_main_GradConf import hogrl_main
        hogrl_main(args)

    def run_gtan_dual():
        from methods.gtan.gtan_main_dual import gtan_main_dual
        gtan_main_dual(args)

    def run_rgtan_dual():
        from methods.rgtan.rgtan_main_dual import rgtan_main_dual
        rgtan_main_dual(args)

    dispatch = {
        'mcnn': run_mcnn,
        'stan_2d': run_stan_2d,
        'stan': run_stan_3d,
        'stagn': run_stagn,
        'gtan': run_gtan,
        'rgtan': run_rgtan,
        'hogrl': run_hogrl,
        'hogrl_GradConf': run_hogrl_gradconf,
        'gtan_dual': run_gtan_dual,
        'rgtan_dual': run_rgtan_dual,
    }
    handler = dispatch.get(args['method'])
    if handler is None:
        raise NotImplementedError("Unsupported method. ")
    handler()


if __name__ == "__main__":
    # Merge CLI flags with the YAML config, then run the selected method.
    run_args = parse_args()
    main(run_args)
